#include "ai_google_net_image_division.h"

/**
 * Build the GoogLeNet classification view and point it at the model assets.
 *
 * NOTE(review): the three asset paths below are absolute, machine-specific
 * locations; they will need to change on any other machine.
 * NOTE(review): the window title says "image segmentation", but this view
 * performs whole-image classification.
 */
AI_Google_Net_Image_Division::AI_Google_Net_Image_Division(QWidget *parent)
    : CommonGraphicsView{parent}
{
    setWindowTitle("使用GoogleNet实现图像分割");

    // Network definition (layer layout) consumed by readNetFromCaffe.
    prototxtPath = "/Users/yangwei/Documents/dev_tools/opencv_ai_model/googlenet/bvlc_googlenet.prototxt";
    // Trained weights consumed by readNetFromCaffe.
    modelPath = "/Users/yangwei/Documents/dev_tools/opencv_ai_model/googlenet/bvlc_googlenet.caffemodel";
    // Text file with one class label per line, read by readLabels().
    labelPath = "/Users/yangwei/Documents/dev_tools/opencv_ai_model/googlenet/synset_words.txt";
}


/**
 * Handle a drop onto the view: take the first dropped URL as a local file
 * path, remember it in `path`, and run classification on it.
 *
 * Drops that carry no URLs (e.g. dragged plain text) are ignored instead of
 * indexing an empty URL list.
 */
void AI_Google_Net_Image_Division::dropEvent(QDropEvent * event){
    const auto *mimeData = event->mimeData();
    if (mimeData == nullptr || !mimeData->hasUrls() || mimeData->urls().isEmpty()) {
        qDebug()<<"drop ignored: no URLs in the dropped data";
        return;
    }
    path = mimeData->urls().first().toLocalFile();
    qDebug()<<"图片路径:"<<path;
    // The temporary std::string lives until the end of this full-expression,
    // so the c_str() pointer stays valid for the whole call.
    showImageClassificatio(path.toStdString().c_str());
}

void AI_Google_Net_Image_Division::showImageClassificatio(const char * filePath){
    //加载原图
    Mat src = imread(filePath);
    if(src.empty()){
        qDebug()<<"图像路径为空";
        return;
    }
    imshow("src",src);
    //组装分类标签
    vector<String> labels = readLabels();//读取分类标签
    //载入网络模型
    Net net = readNetFromCaffe(prototxtPath.toStdString(),modelPath.toStdString());
    if(net.empty()){
        qDebug()<<"count not read net ....";
        return;
    }
    //对要被检测的图像进行预处理
    Mat inputBlob = blobFromImage(src,1.0,Size(224,224), Scalar(104, 117, 123));
    Mat prob;
    //给网络模型设置输入并传递到网络模型尾部
    for(int i=0;i<10;i++){//10层网络
        net.setInput(inputBlob,"data");
        prob = net.forward("prob");
    }
    //将prob转换为一行一列
    Mat probMat = prob.reshape(1,1);
    Point classNumber;
    double classProb;
    //找出模型识别到的最大概率
    minMaxLoc(probMat,NULL,&classProb,NULL,&classNumber);//找出最大值（最大概率）
    int classIndex = classNumber.x;
    //输出检测到的结果
    qDebug()<<"检测结果："<<labels.at(classIndex).c_str()<<"--possible:"<<classProb;
    //将结果会址出来
    putText(src, labels.at(classIndex), Point(20, 20), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(0, 0, 255), 2, 8);
    imshow("src",src);

}

/**
 * Read the classification labels from `labelPath`.
 *
 * For each non-empty line, everything after the first space is kept (the
 * whole line if it has no space). Non-empty line i of the file becomes element
 * i of the result, so the first line maps to class index 0.
 * @brief AI_Google_Net_Image_Division::readLabels
 * @return the labels in file order, or an empty vector if the file cannot be
 *         opened (the failure is logged via qDebug).
 */
vector<String> AI_Google_Net_Image_Division::readLabels(){
    vector<String> classNames;
    QFile file(labelPath);
    if (!file.open(QIODevice::ReadOnly | QIODevice::Text)) {
        qDebug()<<"could not open the file";
        // Let the caller decide what to do instead of terminating the process.
        return classNames;
    }
    QTextStream in(&file);
    // Read every line up to EOF. The previous loop read the first line, then
    // read the next one before storing, so label 0 was lost and all indices
    // were shifted by one. It also stopped at the first blank line.
    while (!in.atEnd()) {
        const string name = in.readLine().toStdString();
        if (name.empty()) {
            continue;
        }
        classNames.push_back(name.substr(name.find(' ') + 1));
    }
    file.close();
    return classNames;
}
